"""
PINF for the TFP equation

p_t - 0.5*Delta(p) + 2*sum_i dp/dx_i = 0,  x in R^d,  t in [0, 1]
p(x, 0) = (2*pi)**(-d/2) * exp(-|x|**2/2)
The exact solution is
p(x, t) = (2*pi*(t+1))**(-d/2) * exp(-|x - 2*t*1|**2 / (2*(t+1))),
where 1 = (1, ..., 1) in R^d and |.| is the Euclidean norm.

"""

import os
import time
import copy
import datetime
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

#from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint

# Seed all CUDA generators and the default (CPU) generator with 111 so every
# run starts from the same RNG state (parameter init, sampling of x0 and t).
torch.cuda.manual_seed_all(111)
torch.manual_seed(111)

def antiderivTanh(x):
    """Return log(2*cosh(x)) elementwise, an antiderivative of tanh.

    Evaluated as |x| + log1p(exp(-2|x|)). This form never exponentiates a
    positive number, so it cannot overflow for large |x|. Using log1p keeps
    full precision in the log term when exp(-2|x|) is tiny. The derivative of
    the result with respect to x is tanh(x).
    """
    abs_x = torch.abs(x)
    return abs_x + torch.log1p(torch.exp(-2.0 * abs_x))

def derivTanh(x):
    """Return the derivative of tanh, 1 - tanh(x)**2, elementwise."""
    tanh_x = torch.tanh(x)
    return 1 - tanh_x.pow(2)

class ResNN(torch.nn.Module):
    """Residual network mapping inputs with d+1 features to m features.

    With s the input, K_i/b_i the linear layers and act = antiderivTanh:
        z_0 = act(K_0 s + b_0)                         K_0: (d+1) -> m
        z_i = z_{i-1} + h * act(K_i z_{i-1} + b_i)     i = 1 .. nTh-1
    where h = 1 / (nTh - 1). Residual layers 2 .. nTh-1 start as deep copies
    of residual layer 1. They therefore share the same initial weights, but each
    holds its own parameters.

    Args:
        d: spatial dimension; inputs have d + 1 features (space plus time).
        m: hidden width.
        nTh: total number of linear layers; must be >= 2.

    Raises:
        ValueError: if nTh < 2.
    """

    def __init__(self, d, m, nTh=2):
        super().__init__()

        # Raise instead of print()+exit(): a constructor should not terminate
        # the interpreter, and `exit` is a site helper that may be absent.
        if nTh < 2:
            raise ValueError("nTh must be an integer >= 2")

        self.d = d
        self.m = m
        self.nTh = nTh
        self.layers = torch.nn.ModuleList([])
        self.layers.append(torch.nn.Linear(d + 1, m, bias=True))  # opening layer
        self.layers.append(torch.nn.Linear(m, m, bias=True))      # first residual layer
        for _ in range(nTh - 2):
            self.layers.append(copy.deepcopy(self.layers[1]))
        self.act = antiderivTanh
        self.h = 1.0 / (self.nTh - 1)  # step size for the ResNet

    def forward(self, x):
        """Apply the network to x with trailing size d+1; the output has trailing size m."""
        x = self.act(self.layers[0](x))
        for layer in self.layers[1:]:
            x = x + self.h * self.act(layer(x))
        return x

    
class u(torch.nn.Module):
    def __init__(self, nTh, m, d, r=10, alph=None):
        """
            neural network approximating u:
            u(x,t) = w'*ResNet([x;t]) + 0.5*[x;t] * A'A * [x;t] + b'*[x;t] + c

        Args:
            nTh: number of linear layers in the ResNet (>= 2).
            m: ResNet hidden width.
            d: spatial dimension; inputs s = [x; t] have d + 1 columns.
            r: rank cap of A (A has shape (min(r, d+1), d+1)).
            alph: stored as ``self.alph``; not read by this class. Defaults
                to a fresh ``[1.0] * 6`` per instance.
        """
        super().__init__()

        # A fresh list per instance: a list literal as the default value would be
        # one object shared (and mutable) across every u ever constructed.
        if alph is None:
            alph = [1.0] * 6

        self.m = m
        self.nTh = nTh
        self.d = d
        self.alph = alph

        r = min(r, d + 1)  # if number of dimensions is smaller than default r, use that

        self.A = torch.nn.Parameter(torch.zeros(r, d + 1), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.A)  # in-place on the Parameter
        self.c = torch.nn.Linear(d + 1, 1, bias=True)  # b'*[x;t] + c
        self.w = torch.nn.Linear(m, 1, bias=False)

        self.N = ResNN(d, m, nTh=nTh)
        torch.nn.init.xavier_uniform_(self.w.weight)
        torch.nn.init.xavier_uniform_(self.c.weight)

    def forward(self, s):
        """Evaluate u at s = [x; t] of shape (batch, d+1); returns shape (batch, 1)."""
        symA = torch.matmul(torch.t(self.A), self.A)  # A'A, (d+1, d+1)
        quadratic = 0.5 * torch.sum(torch.mm(s, symA) * s, dim=1, keepdim=True)
        return self.w(self.N(s)) + quadratic + self.c(s)


def trace_df_dz(f, z):
    """Per-row trace of the Jacobian df/dz, using one backward pass per column of f.

    Each column f[:, j] is summed over the batch before differentiating. The
    j-th column of that gradient is therefore d f_j / d z_j for each row only if
    row i of f depends on row i of z alone. That is an assumption about the
    caller's f; verify it holds.

    Args:
        f: output tensor of shape (batch, k), built from z with autograd enabled.
        z: input tensor with at least k columns that f was computed from.

    Returns:
        Tensor of shape (batch,) holding sum_j d f[:, j] / d z[:, j]. The graph
        is kept (create_graph=True) so the result can be differentiated again.
    """
    diagonal_terms = (
        torch.autograd.grad(f[:, col].sum(), z, create_graph=True)[0][:, col]
        for col in range(f.shape[1])
    )
    return sum(diagonal_terms, 0.).contiguous()


class logp_net(torch.nn.Module):
    """Log-density model that satisfies the initial condition by construction.

        log p(x, t) = log p0(x) + t * u(s),   s = [x; t] rescaled to [-1, 1]

    p0 is the standard normal N(0, I_dim) (see ``log_p0``), and u is the
    potential network ``self.net``. At t = 0 the second term vanishes, so
    log p(x, 0) = log p0(x) exactly.

    Args:
        nTh, m: depth and width passed to ``u``.
        dim: spatial dimension.
        T: final time; upper bound of the time coordinate used for rescaling.
        device: device on which the rescaling bounds are stored.
    """

    def __init__(self, nTh, m, dim, T, device):
        super().__init__()
        self.dim = dim
        self.T = T
        self.device = device
        # Corners of the box [-5, 5]^dim x [0, T]; forward() maps it onto [-1, 1].
        low_corner = [-5.0] * self.dim + [0.0]
        high_corner = [5.0] * self.dim + [self.T]
        self.lb = torch.tensor([low_corner]).to(self.device)
        self.ub = torch.tensor([high_corner]).to(self.device)
        self.net = u(nTh=nTh, m=m, d=dim)

    def log_p0(self, x):
        """Log-density of N(0, I_dim) for each row of x (shape (N, dim)); returns (N, 1)."""
        sq_radius = torch.norm(x, dim=1, keepdim=True) ** 2
        log_two_pi = torch.log(torch.tensor(2) * torch.pi)
        return -0.5 * self.dim * log_two_pi - 0.5 * sq_radius

    def forward(self, x, t):
        """Return log p(x, t) for x of shape (N, dim) and t of shape (N, 1); result is (N, 1)."""
        xt = torch.cat([x, t], dim=1)
        xt.requires_grad_(True)
        # Affine map of each coordinate from [lb, ub] onto [-1, 1].
        scaled = 2.0 * (xt - self.lb) / (self.ub - self.lb) - 1.0
        correction = self.net(scaled)
        return self.log_p0(x) + t * correction

    
class ODEs(torch.nn.Module):
    """Right-hand side for odeint over the state (x, log p).

    Holds the drift mu (all entries 2, shape (1, dim)) and the log-density
    network ``phi``. The same network defines the velocity field, and it is
    the object trained against the integrated log-density.
    """

    def __init__(self, nTh, m, dim, T, device):
        super().__init__()
        self.dim = dim
        self.T = T
        self.device = device
        self.mu = 2 * torch.ones(size=(1, dim), device=device)
        self.phi = logp_net(nTh, m, dim, T, device)

    def forward(self, t, states):
        """Return (dx/dt, d log p/dt) at time t.

            dx/dt      = mu - 0.5 * grad_x log p(x, t)
            d log p/dt = -div_x(dx/dt)

        Only ``states[0]`` (x, shape (N, dim)) is read. The current log p value
        does not enter the right-hand side. Autograd is re-enabled locally
        because the velocity needs grad_x log p even when the caller runs
        without gradients.

        Returns:
            (velocity of shape (N, dim), dlogp of shape (N, 1)).
        """
        x = states[0]
        n_rows = x.shape[0]
        t_col = t.view(1, 1).repeat(n_rows, 1)  # scalar time broadcast to one per row

        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            t_col.requires_grad_(True)

            log_density = self.phi(x, t_col)
            score = torch.autograd.grad(log_density.sum(), x, create_graph=True)[0]

            velocity = self.mu - 0.5 * score  # mu (1, dim) broadcasts over rows
            dlogp = -trace_df_dz(velocity, x).view(n_rows, 1)

        return (velocity, dlogp)


def create_Xd(n, d, T, dim, device):
    """Build an n**d-point tensor grid on [-5, 5]^d, padded to `dim` columns.

    Row r holds axis[(r // n**k) % n] in column k (k < d), so column 0 varies
    fastest and column d-1 slowest. Columns d .. dim-1 are filled with the
    constant 2*T.

    Args:
        n: points per axis.
        d: number of gridded coordinates.
        T: time value; the padding columns are set to 2*T.
        dim: total number of columns (must be >= d).
        device: torch device for the result.

    Returns:
        Tensor of shape (n**d, dim).
    """
    axis = torch.linspace(-5.0, 5.0, n, device=device)
    grid = axis.unsqueeze(1)
    for k in range(1, d):
        # New slowest coordinate: each axis value repeated over a block of
        # n**k rows, with the existing grid tiled n times beside it.
        slow = axis.repeat_interleave(n ** k).unsqueeze(1)
        grid = torch.cat([grid.repeat(n, 1), slow], dim=1)
    padding = 2 * T * torch.ones(size=(n ** d, dim - d), device=device)
    return torch.cat([grid, padding], dim=1)


class PINF(torch.nn.Module):
    """Physics-informed normalizing flow (PINF) solver for the Fokker-Planck problem.

    Two estimates of log p(x, t) are maintained:

    * ``logp_pred`` evaluates the network ``self.model.phi`` directly.
    * ``ode_pred`` integrates ``self.model`` (velocity mu - 0.5 * grad log p,
      plus the log-density change) from t back to 0, then adds log p0 at the
      end point.

    ``train_model`` fits the first estimate to the second along trajectories
    pushed forward from p0 = N(0, I). ``test``/``plot_fig`` compare both
    estimates with the exact solution ``p_true`` at t = T on a fixed grid.

    Args:
        config: dict with keys 'T', 'nTh', 'm', 'dim', 'N_in', 'lr', 'epoch',
            'train_save_freq', 'test_freq', 'log_freq', 'device', 'path'.
    """

    def __init__(self, config):
        super().__init__()
        self.T = config['T']
        self.nTh = config['nTh']
        self.m = config['m']
        self.dim = config['dim']
        self.N_in = config['N_in']
        self.lr = config['lr']
        self.epoch = config['epoch']
        self.train_save_freq = config['train_save_freq']
        self.test_freq = config['test_freq']
        self.log_freq = config['log_freq']
        self.device = config['device']
        self.path = config['path']

        self.p_test = None  # exact density on the test grid; built lazily by test()

        # Initial density p0 = N(0, I_dim), used to sample trajectory starts.
        self.p_x0 = torch.distributions.MultivariateNormal(
                loc=torch.zeros(self.dim, device=self.device),
                covariance_matrix=torch.eye(self.dim, device=self.device))

        self.model = ODEs(self.nTh, self.m, self.dim, self.T, self.device)
        self.opt_Adam = torch.optim.Adam(params=self.model.parameters(), lr=self.lr)

        # Iterations at which Loss / test errors are recorded (x-axes for plots).
        self.Epoch = list(range(0, self.epoch + 1, self.train_save_freq))
        self.Epoch_test = list(range(0, self.epoch + 1, self.test_freq))

        self.loss = torch.nn.MSELoss()
        self.test_p_net = {'mae':[], 'mape':[], 'mse':[]}
        self.test_p_ode = {'mae':[], 'mape':[], 'mse':[]}
        self.Loss = []

    def p_true(self, x, t):
        """Exact density (2*pi*(t+1))**(-dim/2) * exp(-|x - 2t|^2 / (2(t+1))).

        x has shape (N, dim) and t has shape (N, 1); returns shape (N, 1).
        """
        # t (N, 1) broadcasts across the dim columns: the mean is 2t in every coordinate.
        xn = torch.norm(x - 2 * t, dim=1, keepdim=True)
        p = (2*torch.pi*(t+1))**(-self.dim/2) * torch.exp(-xn**2/(2*(t+1)))
        return p

    def logp_pred(self, x, t):
        """Network estimate of log p(x, t); x is (N, dim), t is (N, 1). Returns (N, 1)."""
        logp_x = self.model.phi(x, t)
        return logp_x

    def ode_pred(self, x, t):
        """log p(x, t) obtained by integrating the flow from time t back to 0.

        Args:
            x: points of shape (N, dim).
            t: scalar end time (Python number).

        Returns:
            (N, 1) tensor: log p0(x(0)) - delta, where delta is the change in
            log p accumulated along the trajectory from t to 0.
        """
        bs = x.shape[0]
        logp_zero = torch.zeros(bs, 1).type(torch.float32).to(self.device)
        ts = torch.tensor([t, 0.]).type(torch.float32).to(self.device)

        # NOTE(review): 'rk4' is a fixed-grid method in torchdiffeq; atol/rtol are
        # ignored, and without options={'step_size': ...} it steps only on the
        # `ts` grid (one step here). Confirm this accuracy is intended.
        x_t, logp_diff_t = odeint(
            self.model,
            (x, logp_zero),
            ts,
            atol=1e-8,
            rtol=1e-8,
            method='rk4',
        )

        x0, delta_logp = x_t[-1], logp_diff_t[-1]

        logp_x = self.model.phi.log_p0(x0) - delta_logp

        return logp_x

    def train_PINF(self):
        """Run ``self.epoch + 1`` training iterations with periodic test/log/plot."""
        print("Start training!")
        start = time.time()
        for it in range(self.epoch + 1):
            # Test before the update, so iteration 0 records the untrained model
            # and also builds the test grid that plot_fig relies on.
            if it % self.test_freq == 0:
                self.test()

            # Train
            train_start = time.time()
            train_loss = self.train_model()
            if it % self.train_save_freq == 0:
                self.Loss.append(train_loss)
            train_iteration_time = time.time() - train_start

            # Print: this iteration's time scaled by log_freq, a rough estimate
            # of the time per logging interval.
            if it % self.log_freq == 0:
                print('It: %d, Time: %.2f, Loss: %.3e' % (it, train_iteration_time * self.log_freq, train_loss))
            # Plot
            if it % 1000 == 0:
                self.plot_fig(it)

        elapsed = time.time() - start
        print('Training complete! Total time: %.2f h' % (elapsed/3600))

    def train_model(self):
        """Perform one Adam step on a fresh batch and return the scalar loss.

        Samples x0 ~ p0 and t ~ U(0, T), then integrates the flow from 0 to t
        to get (x_t, log p_ode(x_t, t)). It minimizes the MSE between the
        network's log p(x_t, t) and log p_ode. Neither term is detached, so
        gradients also flow through the ODE solve.
        """
        self.opt_Adam.zero_grad()

        # sampling
        x0 = self.p_x0.sample([self.N_in]).to(self.device)
        logp_t0 = self.p_x0.log_prob(x0).to(self.device)
        t = torch.rand(size=(1,1), dtype=torch.float, device=self.device, requires_grad=True) * self.T
        t_scalar = t.item()

        ts = torch.tensor([0., t_scalar]).type(torch.float32).to(self.device)

        # NOTE(review): with method='rk4', atol/rtol have no effect (fixed grid).
        x_t, logp_diff_t = odeint(
            self.model,
            (x0, logp_t0),
            ts,
            atol=1e-6,
            rtol=1e-6,
            method='rk4',
        )

        x_in, log_p_ode = x_t[-1], logp_diff_t[-1]
        t_in = t.repeat(self.N_in, 1)
        log_p_net = self.logp_pred(x_in, t_in)

        loss = self.loss(log_p_net, log_p_ode.unsqueeze(1))

        loss.backward()
        self.opt_Adam.step()

        return loss.item()

    def test(self):
        """Evaluate both estimates at t = T on a fixed grid and record the errors.

        The grid is built once. For dim == 1 it is 50 points on [-5, 5]. Otherwise
        it is a 50 x 50 grid over (x_1, x_2), with the remaining coordinates set
        to 2*T (see create_Xd).
        """
        if self.p_test is None:
            if self.dim == 1:
                # Column vector: logp_pred and p_true expect x of shape (N, dim).
                self.x = torch.linspace(-5, 5, 50, device=self.device).unsqueeze(1)
            else:
                self.x = create_Xd(50, 2, self.T, self.dim, self.device)
            self.t = torch.full((self.x.shape[0], 1), float(self.T), device=self.device)
            self.p_test = self.p_true(self.x, self.t)

        p_net = torch.exp(self.logp_pred(self.x, self.t))
        p_ode = torch.exp(self.ode_pred(self.x, self.T))
        self.cal_error(p_net, method='net')
        self.cal_error(p_ode)

    def cal_error(self, p_pred, method='ode'):
        """Append MAE / MAPE / MSE of p_pred vs. self.p_test to the history and print them.

        Args:
            p_pred: predicted density with the same shape as self.p_test.
            method: 'ode' records into test_p_ode; any other value into test_p_net.
        """
        mae = torch.abs(p_pred - self.p_test).mean().item()
        mape = torch.abs((p_pred - self.p_test)/self.p_test).mean().item()
        mse = self.loss(p_pred, self.p_test).item()

        if method == 'ode':
            self.test_p_ode['mae'].append(mae)
            self.test_p_ode['mape'].append(mape)
            self.test_p_ode['mse'].append(mse)
            print('Predict by solving ODEs:  MAE: %.3e, MAPE: %.3e, MSE: %.3e' % (mae, mape, mse))
        else:
            self.test_p_net['mae'].append(mae)
            self.test_p_net['mape'].append(mape)
            self.test_p_net['mse'].append(mse)
            print('Predict by p_net:  MAE: %.3e, MAPE: %.3e, MSE: %.3e' % (mae, mape, mse))

    def plot_fig(self, epoch):
        """Save a figure comparing the exact, network and ODE densities at t = T.

        For dim >= 2 the figure has surfaces over the (x_1, x_2) test grid plus
        a contour map of the pointwise relative error of p_ode. For dim == 1 it
        uses line plots. The figure is written to
        ``<path>/pred_true_<epoch>.png``. test() must have run first, because it
        builds self.x, self.t and self.p_test.
        """
        p_net = torch.exp(self.logp_pred(self.x, self.t))
        p_ode = torch.exp(self.ode_pred(self.x, self.T))
        p_true = self.p_test
        mape = torch.abs(p_ode - p_true)/p_true

        X = self.x.cpu().detach().numpy()
        p_net_data = p_net.cpu().detach().numpy()
        p_ode_data = p_ode.cpu().detach().numpy()
        p_true_data = p_true.cpu().detach().numpy()
        mape_data = mape.cpu().detach().numpy()
        out_file = os.path.join(self.path, "pred_true_" + str(epoch) + ".png")

        if self.dim == 1:
            # There is no x_2 axis to plot against, so use line plots instead.
            self._plot_1d(out_file, X[:, 0], p_true_data, p_net_data, p_ode_data, mape_data)
            return

        n = 50  # side length of the test grid built in test()
        X1 = X[:, 0].reshape(n, n)
        X2 = X[:, 1].reshape(n, n)

        fig = plt.figure(figsize=(22, 4), dpi=200)
        surfaces = (
            ('Ground truth', p_true_data),
            ('$p_{net}$', p_net_data),
            ('$p_{ode}$', p_ode_data),
        )
        for col, (title, data) in enumerate(surfaces, start=1):
            ax = fig.add_subplot(1, 4, col, projection='3d')
            ax.plot_surface(X1, X2, data.reshape(n, n), cmap='rainbow')
            ax.set_title(title)

        ax4 = fig.add_subplot(1, 4, 4)
        ax4.set_title('mape')
        # contourf needs strictly increasing levels; fall back to [0, 1] when the
        # error field is identically zero or contains NaN.
        mape_max = float(np.max(mape_data))
        if not mape_max > 0:
            mape_max = 1.0
        cntr4 = ax4.contourf(X1, X2, mape_data.reshape(n, n),
                             levels=np.linspace(0, mape_max, 101), cmap='viridis')

        fig.colorbar(cntr4, ax=ax4)
        ax4.set_xlabel('$x_1$')
        ax4.set_ylabel('$x_2$')
        ax4.set_aspect('equal')
        ax4.set_xticks(np.linspace(-5, 5, 11))
        ax4.set_yticks(np.linspace(-5, 5, 11))

        plt.savefig(out_file, bbox_inches='tight')
        plt.close(fig)

    def _plot_1d(self, out_file, x, p_true, p_net, p_ode, mape):
        """Line-plot counterpart of plot_fig for dim == 1 (densities, then relative error)."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), dpi=200)
        ax1.plot(x, p_true.ravel(), 'k-', label='Ground truth')
        ax1.plot(x, p_net.ravel(), 'r--', label='$p_{net}$')
        ax1.plot(x, p_ode.ravel(), 'g:', label='$p_{ode}$')
        ax1.set_xlabel('$x$')
        ax1.legend()
        ax2.plot(x, mape.ravel())
        ax2.set_title('mape')
        ax2.set_xlabel('$x$')
        fig.savefig(out_file, bbox_inches='tight')
        plt.close(fig)


if __name__ == "__main__":

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ---- Hyperparameters ----------------------------------------------------
    T = 1                  # final time
    nTh = 5                # ResNN layers: 1 input layer + (nTh - 1) residual layers (L+1)
    m = 32                 # ResNN hidden width
    dim = 10               # spatial dimension d
    N_in = 2000            # trajectories sampled from p0 per training step
    lr = 1e-2              # Adam learning rate
    epoch = 10000          # last iteration index (epoch + 1 iterations in total)
    train_save_freq = 25   # record the training loss every this many iterations
    test_freq = 250        # evaluate test errors every this many iterations
    log_freq = 50          # print progress every this many iterations

    # Results go to a directory named after the launch time (created if missing).
    path = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    if not os.path.exists(path):
        os.makedirs(path)

    config = dict(
        T=T,
        nTh=nTh,
        m=m,
        dim=dim,
        N_in=N_in,
        lr=lr,
        epoch=epoch,
        train_save_freq=train_save_freq,
        test_freq=test_freq,
        log_freq=log_freq,
        device=device,
        path=path,
    )

    model = PINF(config).to(device)
    model.train_PINF()

    # ---- Save the whole PINF object (pickled module, not a state_dict) ------
    netname = 'PINF.pth'
    torch.save(model, f"{path}/{netname}")

    # ---- Training-loss curve ------------------------------------------------
    plt.figure(figsize=(8, 6))
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.plot(model.Epoch, model.Loss)
    plt.savefig(path + '/Training Loss.png')
    plt.close()

    # ---- Test-error curves: one figure per metric, p_net vs. p_ode ----------
    error_figures = (
        ('mae', 'MAE', '/Prediction MAE.png'),
        ('mape', 'MAPE', '/Prediction MAPE.png'),
    )
    for metric_key, metric_label, file_suffix in error_figures:
        plt.figure(figsize=(8, 6))
        plt.title('Testing error')
        plt.xlabel('Epoch')
        plt.ylabel(metric_label)
        plt.yscale('log')
        plt.plot(model.Epoch_test, model.test_p_net[metric_key], 'r', label='p_net')
        plt.plot(model.Epoch_test, model.test_p_ode[metric_key], 'g', label='p_ode')
        plt.legend()
        plt.savefig(path + file_suffix)
        plt.close()
